{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict  (before training) 4 4.0\n",
      "\tgrad: 1.0 2.0 -2.0\n",
      "\tgrad: 2.0 4.0 -7.840000152587891\n",
      "\tgrad: 3.0 6.0 -16.228801727294922\n",
      "progress: 0 7.315943717956543\n",
      "\tgrad: 1.0 2.0 -1.478623867034912\n",
      "\tgrad: 2.0 4.0 -5.796205520629883\n",
      "\tgrad: 3.0 6.0 -11.998146057128906\n",
      "progress: 1 3.9987640380859375\n",
      "\tgrad: 1.0 2.0 -1.0931644439697266\n",
      "\tgrad: 2.0 4.0 -4.285204887390137\n",
      "\tgrad: 3.0 6.0 -8.870372772216797\n",
      "progress: 2 2.1856532096862793\n",
      "\tgrad: 1.0 2.0 -0.8081896305084229\n",
      "\tgrad: 2.0 4.0 -3.1681032180786133\n",
      "\tgrad: 3.0 6.0 -6.557973861694336\n",
      "progress: 3 1.1946394443511963\n",
      "\tgrad: 1.0 2.0 -0.5975041389465332\n",
      "\tgrad: 2.0 4.0 -2.3422164916992188\n",
      "\tgrad: 3.0 6.0 -4.848389625549316\n",
      "progress: 4 0.6529689431190491\n",
      "\tgrad: 1.0 2.0 -0.4417421817779541\n",
      "\tgrad: 2.0 4.0 -1.7316293716430664\n",
      "\tgrad: 3.0 6.0 -3.58447265625\n",
      "progress: 5 0.35690122842788696\n",
      "\tgrad: 1.0 2.0 -0.3265852928161621\n",
      "\tgrad: 2.0 4.0 -1.2802143096923828\n",
      "\tgrad: 3.0 6.0 -2.650045394897461\n",
      "progress: 6 0.195076122879982\n",
      "\tgrad: 1.0 2.0 -0.24144840240478516\n",
      "\tgrad: 2.0 4.0 -0.9464778900146484\n",
      "\tgrad: 3.0 6.0 -1.9592113494873047\n",
      "progress: 7 0.10662525147199631\n",
      "\tgrad: 1.0 2.0 -0.17850565910339355\n",
      "\tgrad: 2.0 4.0 -0.699742317199707\n",
      "\tgrad: 3.0 6.0 -1.4484672546386719\n",
      "progress: 8 0.0582793727517128\n",
      "\tgrad: 1.0 2.0 -0.1319713592529297\n",
      "\tgrad: 2.0 4.0 -0.5173273086547852\n",
      "\tgrad: 3.0 6.0 -1.070866584777832\n",
      "progress: 9 0.03185431286692619\n",
      "\tgrad: 1.0 2.0 -0.09756779670715332\n",
      "\tgrad: 2.0 4.0 -0.3824653625488281\n",
      "\tgrad: 3.0 6.0 -0.7917022705078125\n",
      "progress: 10 0.017410902306437492\n",
      "\tgrad: 1.0 2.0 -0.07213282585144043\n",
      "\tgrad: 2.0 4.0 -0.2827606201171875\n",
      "\tgrad: 3.0 6.0 -0.5853137969970703\n",
      "progress: 11 0.009516451507806778\n",
      "\tgrad: 1.0 2.0 -0.053328514099121094\n",
      "\tgrad: 2.0 4.0 -0.2090473175048828\n",
      "\tgrad: 3.0 6.0 -0.43272972106933594\n",
      "progress: 12 0.005201528314501047\n",
      "\tgrad: 1.0 2.0 -0.039426326751708984\n",
      "\tgrad: 2.0 4.0 -0.15455150604248047\n",
      "\tgrad: 3.0 6.0 -0.3199195861816406\n",
      "progress: 13 0.0028430151287466288\n",
      "\tgrad: 1.0 2.0 -0.029148340225219727\n",
      "\tgrad: 2.0 4.0 -0.11426162719726562\n",
      "\tgrad: 3.0 6.0 -0.23652076721191406\n",
      "progress: 14 0.0015539465239271522\n",
      "\tgrad: 1.0 2.0 -0.021549701690673828\n",
      "\tgrad: 2.0 4.0 -0.08447456359863281\n",
      "\tgrad: 3.0 6.0 -0.17486286163330078\n",
      "progress: 15 0.0008493617060594261\n",
      "\tgrad: 1.0 2.0 -0.01593184471130371\n",
      "\tgrad: 2.0 4.0 -0.062453269958496094\n",
      "\tgrad: 3.0 6.0 -0.12927818298339844\n",
      "progress: 16 0.00046424579340964556\n",
      "\tgrad: 1.0 2.0 -0.011778593063354492\n",
      "\tgrad: 2.0 4.0 -0.046172142028808594\n",
      "\tgrad: 3.0 6.0 -0.09557533264160156\n",
      "progress: 17 0.0002537401160225272\n",
      "\tgrad: 1.0 2.0 -0.00870823860168457\n",
      "\tgrad: 2.0 4.0 -0.03413581848144531\n",
      "\tgrad: 3.0 6.0 -0.07066154479980469\n",
      "progress: 18 0.00013869594840798527\n",
      "\tgrad: 1.0 2.0 -0.006437778472900391\n",
      "\tgrad: 2.0 4.0 -0.025236129760742188\n",
      "\tgrad: 3.0 6.0 -0.052239418029785156\n",
      "progress: 19 7.580435340059921e-05\n",
      "\tgrad: 1.0 2.0 -0.004759550094604492\n",
      "\tgrad: 2.0 4.0 -0.018657684326171875\n",
      "\tgrad: 3.0 6.0 -0.038620948791503906\n",
      "progress: 20 4.143271507928148e-05\n",
      "\tgrad: 1.0 2.0 -0.003518819808959961\n",
      "\tgrad: 2.0 4.0 -0.0137939453125\n",
      "\tgrad: 3.0 6.0 -0.028553009033203125\n",
      "progress: 21 2.264650902361609e-05\n",
      "\tgrad: 1.0 2.0 -0.00260162353515625\n",
      "\tgrad: 2.0 4.0 -0.010198593139648438\n",
      "\tgrad: 3.0 6.0 -0.021108627319335938\n",
      "progress: 22 1.2377059647405986e-05\n",
      "\tgrad: 1.0 2.0 -0.0019233226776123047\n",
      "\tgrad: 2.0 4.0 -0.0075397491455078125\n",
      "\tgrad: 3.0 6.0 -0.0156097412109375\n",
      "progress: 23 6.768445018678904e-06\n",
      "\tgrad: 1.0 2.0 -0.0014221668243408203\n",
      "\tgrad: 2.0 4.0 -0.0055751800537109375\n",
      "\tgrad: 3.0 6.0 -0.011541366577148438\n",
      "progress: 24 3.7000872907810844e-06\n",
      "\tgrad: 1.0 2.0 -0.0010514259338378906\n",
      "\tgrad: 2.0 4.0 -0.0041217803955078125\n",
      "\tgrad: 3.0 6.0 -0.008531570434570312\n",
      "progress: 25 2.021880391112063e-06\n",
      "\tgrad: 1.0 2.0 -0.0007772445678710938\n",
      "\tgrad: 2.0 4.0 -0.0030469894409179688\n",
      "\tgrad: 3.0 6.0 -0.006305694580078125\n",
      "progress: 26 1.1044940038118511e-06\n",
      "\tgrad: 1.0 2.0 -0.0005745887756347656\n",
      "\tgrad: 2.0 4.0 -0.0022525787353515625\n",
      "\tgrad: 3.0 6.0 -0.0046634674072265625\n",
      "progress: 27 6.041091182851233e-07\n",
      "\tgrad: 1.0 2.0 -0.0004248619079589844\n",
      "\tgrad: 2.0 4.0 -0.0016651153564453125\n",
      "\tgrad: 3.0 6.0 -0.003444671630859375\n",
      "progress: 28 3.296045179013163e-07\n",
      "\tgrad: 1.0 2.0 -0.0003139972686767578\n",
      "\tgrad: 2.0 4.0 -0.0012311935424804688\n",
      "\tgrad: 3.0 6.0 -0.0025491714477539062\n",
      "progress: 29 1.805076408345485e-07\n",
      "\tgrad: 1.0 2.0 -0.00023221969604492188\n",
      "\tgrad: 2.0 4.0 -0.0009107589721679688\n",
      "\tgrad: 3.0 6.0 -0.0018854141235351562\n",
      "progress: 30 9.874406714516226e-08\n",
      "\tgrad: 1.0 2.0 -0.00017189979553222656\n",
      "\tgrad: 2.0 4.0 -0.0006742477416992188\n",
      "\tgrad: 3.0 6.0 -0.00139617919921875\n",
      "progress: 31 5.4147676564753056e-08\n",
      "\tgrad: 1.0 2.0 -0.0001270771026611328\n",
      "\tgrad: 2.0 4.0 -0.0004978179931640625\n",
      "\tgrad: 3.0 6.0 -0.00102996826171875\n",
      "progress: 32 2.9467628337442875e-08\n",
      "\tgrad: 1.0 2.0 -9.393692016601562e-05\n",
      "\tgrad: 2.0 4.0 -0.0003681182861328125\n",
      "\tgrad: 3.0 6.0 -0.0007610321044921875\n",
      "progress: 33 1.6088051779661328e-08\n",
      "\tgrad: 1.0 2.0 -6.937980651855469e-05\n",
      "\tgrad: 2.0 4.0 -0.00027179718017578125\n",
      "\tgrad: 3.0 6.0 -0.000560760498046875\n",
      "progress: 34 8.734787115827203e-09\n",
      "\tgrad: 1.0 2.0 -5.125999450683594e-05\n",
      "\tgrad: 2.0 4.0 -0.00020122528076171875\n",
      "\tgrad: 3.0 6.0 -0.0004177093505859375\n",
      "progress: 35 4.8466972657479346e-09\n",
      "\tgrad: 1.0 2.0 -3.790855407714844e-05\n",
      "\tgrad: 2.0 4.0 -0.000148773193359375\n",
      "\tgrad: 3.0 6.0 -0.000308990478515625\n",
      "progress: 36 2.6520865503698587e-09\n",
      "\tgrad: 1.0 2.0 -2.8133392333984375e-05\n",
      "\tgrad: 2.0 4.0 -0.000110626220703125\n",
      "\tgrad: 3.0 6.0 -0.0002288818359375\n",
      "progress: 37 1.4551915228366852e-09\n",
      "\tgrad: 1.0 2.0 -2.09808349609375e-05\n",
      "\tgrad: 2.0 4.0 -8.20159912109375e-05\n",
      "\tgrad: 3.0 6.0 -0.00016880035400390625\n",
      "progress: 38 7.914877642178908e-10\n",
      "\tgrad: 1.0 2.0 -1.5497207641601562e-05\n",
      "\tgrad: 2.0 4.0 -6.103515625e-05\n",
      "\tgrad: 3.0 6.0 -0.000125885009765625\n",
      "progress: 39 4.4019543565809727e-10\n",
      "\tgrad: 1.0 2.0 -1.1444091796875e-05\n",
      "\tgrad: 2.0 4.0 -4.482269287109375e-05\n",
      "\tgrad: 3.0 6.0 -9.1552734375e-05\n",
      "progress: 40 2.3283064365386963e-10\n",
      "\tgrad: 1.0 2.0 -8.344650268554688e-06\n",
      "\tgrad: 2.0 4.0 -3.24249267578125e-05\n",
      "\tgrad: 3.0 6.0 -6.580352783203125e-05\n",
      "progress: 41 1.2028067430946976e-10\n",
      "\tgrad: 1.0 2.0 -5.9604644775390625e-06\n",
      "\tgrad: 2.0 4.0 -2.288818359375e-05\n",
      "\tgrad: 3.0 6.0 -4.57763671875e-05\n",
      "progress: 42 5.820766091346741e-11\n",
      "\tgrad: 1.0 2.0 -4.291534423828125e-06\n",
      "\tgrad: 2.0 4.0 -1.71661376953125e-05\n",
      "\tgrad: 3.0 6.0 -3.719329833984375e-05\n",
      "progress: 43 3.842615114990622e-11\n",
      "\tgrad: 1.0 2.0 -3.337860107421875e-06\n",
      "\tgrad: 2.0 4.0 -1.33514404296875e-05\n",
      "\tgrad: 3.0 6.0 -2.86102294921875e-05\n",
      "progress: 44 2.2737367544323206e-11\n",
      "\tgrad: 1.0 2.0 -2.6226043701171875e-06\n",
      "\tgrad: 2.0 4.0 -1.049041748046875e-05\n",
      "\tgrad: 3.0 6.0 -2.288818359375e-05\n",
      "progress: 45 1.4551915228366852e-11\n",
      "\tgrad: 1.0 2.0 -1.9073486328125e-06\n",
      "\tgrad: 2.0 4.0 -7.62939453125e-06\n",
      "\tgrad: 3.0 6.0 -1.430511474609375e-05\n",
      "progress: 46 5.6843418860808015e-12\n",
      "\tgrad: 1.0 2.0 -1.430511474609375e-06\n",
      "\tgrad: 2.0 4.0 -5.7220458984375e-06\n",
      "\tgrad: 3.0 6.0 -1.1444091796875e-05\n",
      "progress: 47 3.637978807091713e-12\n",
      "\tgrad: 1.0 2.0 -1.1920928955078125e-06\n",
      "\tgrad: 2.0 4.0 -4.76837158203125e-06\n",
      "\tgrad: 3.0 6.0 -1.1444091796875e-05\n",
      "progress: 48 3.637978807091713e-12\n",
      "\tgrad: 1.0 2.0 -9.5367431640625e-07\n",
      "\tgrad: 2.0 4.0 -3.814697265625e-06\n",
      "\tgrad: 3.0 6.0 -8.58306884765625e-06\n",
      "progress: 49 2.0463630789890885e-12\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 50 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 51 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 52 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 53 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 54 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 55 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 56 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 57 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 58 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 59 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 60 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 61 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 62 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 63 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 64 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 65 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 66 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 67 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 68 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 69 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 70 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 71 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 72 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 73 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 74 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 75 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 76 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 77 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 78 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 79 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 80 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 81 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 82 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 83 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 84 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 85 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 86 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 87 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 88 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 89 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 90 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 91 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 92 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 93 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 94 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 95 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 96 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 97 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 98 9.094947017729282e-13\n",
      "\tgrad: 1.0 2.0 -7.152557373046875e-07\n",
      "\tgrad: 2.0 4.0 -2.86102294921875e-06\n",
      "\tgrad: 3.0 6.0 -5.7220458984375e-06\n",
      "progress: 99 9.094947017729282e-13\n",
      "predict (after training) 4 7.999998569488525\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\verfallen\\.conda\\envs\\pytorch\\lib\\site-packages\\torch\\autograd\\__init__.py:173: UserWarning: CUDA initialization: CUDA driver initialization failed, you might not have a CUDA gpu. (Triggered internally at  C:\\cb\\pytorch_1000000000000\\work\\c10\\cuda\\CUDAFunctions.cpp:112.)\n",
      "  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "x_data = [1.0, 2.0, 3.0]\n",
    "y_data = [2.0, 4.0, 6.0]\n",
    "\n",
    "#赋予tensor中的data\n",
    "w = torch.Tensor([1.0])\n",
    "#设定需要计算梯度grad\n",
    "w.requires_grad = True\n",
    "\n",
    "#模型y=x*w 建立计算图\n",
    "def forward(x):\n",
    "    '''\n",
    "    w为Tensor类型\n",
    "    x强制转换为Tensor类型\n",
    "    通过这样的方式建立计算图\n",
    "    '''\n",
    "    return x * w\n",
    "\n",
    "def loss(x, y):\n",
    "    y_pred = forward(x)\n",
    "    return (y_pred - y) ** 2\n",
    "\n",
    "print (\"predict  (before training)\", 4, forward(4).item())\n",
    "\n",
    "for epoch in range(100):\n",
    "    for x,y in zip(x_data,y_data):\n",
    "        #创建新的计算图\n",
    "        l = loss(x,y)\n",
    "        #进行反馈计算，此时才开始求梯度，此后计算图进行释放\n",
    "        l.backward()\n",
    "        #grad.item()取grad中的值变成标量\n",
    "        print('\\tgrad:',x, y, w.grad.item())\n",
    "        #单纯的数值计算要利用data，而不能用张量，否则会在内部创建新的计算图\n",
    "        w.data = w.data - 0.01 * w.grad.data\n",
    "        #把权重梯度里的数据清零\n",
    "        w.grad.data.zero_()\n",
    "    print(\"progress:\",epoch, l.item())\n",
    "\n",
    "print(\"predict (after training)\", 4, forward(4).item())"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9a28f02bfe67a97f59d08385be9e40ae75d301e4305d8ead555fed38ae9c259e"
  },
  "kernelspec": {
   "display_name": "Python 3.9.12 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
